#!/usr/bin/env python
# -*- coding: utf-8; py-indent-offset:4 -*-
###############################################################################
#
# Copyright (C) 2015-2020 Daniel Rodriguez
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
###############################################################################
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import argparse
import datetime
import backtrader as bt
# from backtrade.CoinCsv import CoinCsv
import backtrader.indicators as btind
import numpy as np


class CoinCsv(bt.feeds.GenericCSVData):
    '''
    Parses a CSV file of bar (candlestick) data, one row per bar.

    Built on ``GenericCSVData``: only the params (column mapping, date
    format) and the set of lines are changed. The feed exposes the lines
    ``datetime, open, high, low, close, volume, amount, count``: the usual
    OHLCV lines plus two extra ones, ``amount`` and ``count`` (presumably
    quote-asset turnover and number of trades; confirm against whatever
    produces the CSV).

    Specific parameters (or specific meaning):

      - ``dataname``: The filename to parse or a file-like object

      - Column indexes in ``params`` are 0-based; column 6 is not mapped to
        any line and is therefore ignored.

      - ``dtformat``: date and time live together in the datetime column,
        e.g. ``2021-01-18 00:00``.
    '''
    # Replace (rather than extend) the inherited OHLC line set with the
    # custom one declared below
    linesoverride = True

    # Line set: standard OHLCV lines plus the extra amount/count columns
    lines = ('datetime', 'open', 'high', 'low', 'close', 'volume', 'amount', 'count', )

    # Index of each line, in the same order as ``lines`` above
    DateTime, Open, High, Low, Close, Volume, Amount, Count = range(8)

    LineOrder = [DateTime, Open, High, Low, Close, Volume, Amount, Count]

    # Reading configuration: which CSV column (0-based) feeds each line
    params = (
        ('nullvalue', float('NaN')),
        ('dtformat', '%Y-%m-%d %H:%M'),
        ('tmformat', '%H:%M'),
        ('datetime', 0),
        #('time', -1),
        ('open', 1),
        ('high', 2),
        ('low', 3),
        ('close', 4),
        ('volume', 5),
        ('amount', 7),
        ('count', 8),
    )


class CrossOverSt(bt.Strategy):
    '''
    EMA crossover strategy.

    When flat, buys on the bar where the fast EMA of the close crosses above
    the slow EMA. While a position is held, sells on the bar where the fast
    EMA crosses below the slow EMA. No new order is issued while a previous
    one is still pending.

    Params:
      - ``fast`` (default: 7): period of the fast exponential moving average
      - ``slow`` (default: 14): period of the slow exponential moving average
    '''

    alias = ('CrossOverSt',)

    params = (
        # period for the fast Moving Average
        ('fast', 7),
        # period for the slow moving average
        ('slow', 14),
    )

    def log(self, txt, dt=None):
        '''
        Print ``txt`` prefixed with a ``YYYY-MM-DD`` date.

        Args:
          txt: message to print.
          dt: backtrader float date number (the kind stored in the
            ``datetime`` line, ``order.executed.dt``, ``trade.dtopen``...).
            Defaults to the current bar of the first data feed.
        '''
        if dt is None:
            # num2date(None) would raise; fall back to the bar being processed
            dt = self.data.datetime[0]
        print('%s, %s' % (bt.utils.num2date(dt).strftime("%Y-%m-%d"), txt))

    def __init__(self):
        # Indicators created here register themselves on their owner (this
        # strategy, via its internal _lineiterators collection), so they are
        # computed automatically before every call to ``next``.
        ema_fast = btind.EMA(self.data.close, period=self.p.fast)
        ema_slow = btind.EMA(self.data.close, period=self.p.slow)
        # > 0 when fast crosses above slow, < 0 when it crosses below
        self.buysig = btind.CrossOver(ema_fast, ema_slow)
        # Order still in flight (maintained by notify_order), None otherwise
        self.order = None

    def notify_cashvalue(self, cash, value):
        '''Receive the broker cash and portfolio value; currently unused.'''

    def notify_trade(self, trade):
        '''Log trade lifecycle changes (created / opened / closed).'''
        if trade.status == trade.Created:
            # A freshly created trade has no open/close date yet, and
            # ``trade.tradeid`` is an identifier, not a date number: use the
            # bar currently being processed instead.
            status, dt = 'CREATE', self.data.datetime[0]
        elif trade.status == trade.Open:
            status, dt = 'OPEN', trade.dtopen
        elif trade.status == trade.Closed:
            status, dt = 'CLOSED', trade.dtclose
        else:
            return
        self.log('TRADE %s, price: %.2f size: %.2f pnl: %.2f' %
                 (status, trade.price, trade.size, trade.pnl), dt=dt)

    def notify_order(self, order):
        '''
        Log each order status change and maintain ``self.order``.

        While an order is alive (Submitted, Accepted or Partially filled) it
        is kept in ``self.order``. Once it reaches a final status (Completed,
        Canceled, Expired, Margin, Rejected), the sentinel is reset to
        ``None`` so that ``next`` may issue a new order.
        '''
        if order.status in (order.Submitted, order.Accepted, order.Partial):
            # Still in flight: nothing to do but remember it
            self.log('ORDER %s' % order.getstatusname().upper(),
                     dt=order.created.dt)
            self.order = order
            return

        side = 'BUY' if order.isbuy() else 'SELL'
        if order.status == order.Completed:
            self.log('ORDER %s COMPLETED, Price: %.2f, Cost: %.2f, Comm %.2f' %
                     (side,
                      order.executed.price,
                      order.executed.value,
                      order.executed.comm), dt=order.executed.dt)
        elif order.status in (order.Canceled, order.Expired,
                              order.Margin, order.Rejected):
            # The order will never fill; report it for the correct side
            self.log('ORDER %s %s' % (side, order.getstatusname().upper()),
                     dt=order.created.dt)

        # Sentinel to None: new orders allowed
        self.order = None

    def next(self):
        '''
        Act on the crossover signal of the current bar.

        Before ``next`` is called, backtrader has already computed the
        indicators and delivered pending order notifications through
        ``notify_order``; hence ``order.created.dt`` seen there may differ
        from ``self.data.datetime`` here. Any previous order must be
        finished before a new one is placed.
        '''
        dtstr = self.data.datetime.datetime().strftime("%Y-%m-%d")
        print(f"{dtstr}")
        if self.order:
            # A previous order is still pending: wait for its outcome
            return

        if self.position.size:
            if self.buysig < 0:
                print(f"**************Sell: {dtstr}")
                self.order = self.sell()

        elif self.buysig > 0:
            print(f"***************Buy: {dtstr}")
            self.order = self.buy()


def runstrat(pargs=None):
    '''
    Configure and run the EMA-crossover backtest on the CoinCsv feed.

    All settings live in the local ``args`` dict. The engine is set up with:
      - the starting cash and cheat-on-close execution;
      - the CSV feed, restricted to ``fromdate``..``todate``;
      - the ``CrossOverSt`` strategy;
      - a sizer investing 90 % of the available cash in integer sizes;
      - drawdown observers;
      - a Sharpe-ratio analyzer;
      - a file writer (``test.csv``).
    It then runs the backtest and plots the result when ``args["plot"]``
    is true.

    Args:
      pargs: unused; kept so existing callers keep working.
    '''
    args = dict(
        cash=100000,
        fromdate="2021-01-18",
        todate="2021-04-18",
        dataname="/Users/wudi/Workspace/jupyter-analysis/data/binance/market/1d/BTCUSDT.csv",
        # dataname="../datas/orcl-1995-2014.txt",
        # Timeframe used by the Sharpe ratio analyzer (key into ``tframes``)
        tframe="days",
        # Riskfree Rate (annual) for Sharpe
        riskfreerate=0.0,
        # Conversion factor between the annual rate and the timeframe.
        # None lets SharpeRatio pick its default for the timeframe (252 for
        # days). A factor of 0.0 (with convertrate=False) turns every return
        # into 0, so the ratio could never be computed.
        # NOTE(review): a 24/7 crypto market may warrant 365 instead.
        factor=None,
        # Also dump per-bar CSV rows through the writer
        writercsv=False,
        # Show the chart after the run
        plot=True,
    )

    tframes = dict(
        days=bt.TimeFrame.Days,
        weeks=bt.TimeFrame.Weeks,
        months=bt.TimeFrame.Months,
        years=bt.TimeFrame.Years,
        notimeframe=bt.TimeFrame.NoTimeFrame,
    )

    # Create a cerebro
    cerebro = bt.Cerebro()

    cerebro.broker.set_cash(args["cash"])
    # Cheat-on-close: market orders fill at the close of the bar they were
    # issued on
    cerebro.broker.set_coc(True)

    # Get the dates from the args
    fromdate = datetime.datetime.strptime(args["fromdate"], '%Y-%m-%d')
    todate = datetime.datetime.strptime(args["todate"], '%Y-%m-%d')

    data = CoinCsv(
        dataname=args["dataname"],
        fromdate=fromdate,
        todate=todate
    )
    cerebro.adddata(data)

    cerebro.addstrategy(CrossOverSt)
    cerebro.addsizer(bt.sizers.PercentSizerInt, percents=90)

    cerebro.addobserver(bt.observers.DrawDown)
    cerebro.addobserver(bt.observers.DrawDownLength)

    cerebro.addanalyzer(
        bt.analyzers.SharpeRatio,
        timeframe=tframes[args["tframe"]],
        annualize=True,
        riskfreerate=args["riskfreerate"],
        factor=args["factor"],
        stddev_sample=True,
        convertrate=False,
    )

    # Writer output goes to test.csv; per-bar CSV rows only if requested
    cerebro.addwriter(bt.WriterFile, csv=args["writercsv"], out="test.csv",
                      rounding=4)

    cerebro.run(runonce=False)

    # Plot if requested
    if args["plot"]:
        cerebro.plot(style='candle', volume=True)


if __name__ == "__main__":
    # Script entry point: run the backtest with runstrat()'s built-in settings.
    runstrat(pargs=None)
